import itertools
import math
import random
import re
import statistics
import sys
from collections import Counter
from datetime import datetime
from operator import itemgetter
from random import sample
from typing import List, Tuple

import numpy as np
import pandas as pd

import nltk
from imblearn.over_sampling import ADASYN
from matplotlib import pyplot as plt
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.stem.snowball import GermanStemmer
from nltk.tokenize import WordPunctTokenizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.naive_bayes import MultinomialNB


def preprocessfunction(data):
    """Normalise raw German texts into space-joined strings of stemmed tokens.

    Each element of *data* is converted with ``str`` (so a missing value such
    as ``NaN`` becomes the literal text ``"nan"``) and then:

    1. every run of non-word characters is replaced by a single space,
    2. the text is tokenised with ``WordPunctTokenizer``,
    3. each token is stemmed with NLTK's ``GermanStemmer``,
    4. stems that are German stop words or shorter than three characters
       are dropped.

    Args:
        data: iterable of documents; each element is passed through ``str``.

    Returns:
        list of str: one space-joined token string per input document, in
        input order.
    """
    # None of these depend on the document, so build them once per call
    # rather than once per text.
    non_word = re.compile(r'\W+')
    tokenizer = WordPunctTokenizer()
    stemmer = GermanStemmer()
    stop_tokens = set(stopwords.words('german'))

    data_processed = []
    for text in data:
        # Remove punctuation: collapse non-word runs into a single space.
        cleaned = non_word.sub(' ', str(text))
        stems = [stemmer.stem(token) for token in tokenizer.tokenize(cleaned)]
        # Keep stems of at least three characters that are not stop words.
        # NOTE: stop words are matched against the *stemmed* form, so an
        # inflected stop word whose stem is not itself listed survives.
        kept = [
            stem for stem in stems
            if len(stem.lower()) > 2 and stem.lower() not in stop_tokens
        ]
        data_processed.append(' '.join(kept))
    return data_processed


def chunks(lst, amount):
    """Split *lst* into ``amount`` contiguous, nearly equal folds for cross-validation.

    Fold sizes differ by at most one: the first ``len(lst) % amount`` folds
    hold one extra element. Empty folds are never yielded. With fewer than
    ``amount`` elements you get one single-element fold per element. An empty
    *lst* yields nothing.

    Args:
        lst: a sliceable sequence (list, tuple, str, ...).
        amount: number of folds requested; must be at least 1.

    Yields:
        Consecutive slices of *lst* that together cover it in order.

    Raises:
        ValueError: if *amount* is smaller than 1.
    """
    if amount < 1:
        raise ValueError(f"amount must be at least 1, got {amount}")
    base, extra = divmod(len(lst), amount)
    start = 0
    for fold in range(amount):
        stop = start + base + (1 if fold < extra else 0)
        if stop == start:
            # Only happens when len(lst) < amount; every later fold is empty too.
            break
        yield lst[start:stop]
        start = stop


def unzip(lst):
    """Split a sequence of 2-tuples into a tuple of two lists.

    ``unzip([(a1, b1), (a2, b2)])`` returns ``([a1, a2], [b1, b2])``.
    *lst* is traversed twice, once per output list, so pass a re-iterable
    sequence rather than a one-shot iterator.
    """
    firsts = [first for first, _second in lst]
    seconds = [second for _first, second in lst]
    return firsts, seconds


def split_datafunction(data, labels, max_splits, seed=None):
    """Build cross-validation folds from parallel sequences of texts and labels.

    The ``(text, label)`` pairs are shuffled and cut into up to ``max_splits``
    chunks with ``chunks``. Each chunk serves once as the test split. The
    remaining chunks, concatenated in order, form the training split.

    Args:
        data: sized iterable of documents.
        labels: sized iterable of labels aligned with *data*.
        max_splits: number of folds requested.
        seed: optional seed for a private ``random.Random`` so the folds are
            reproducible. With the default ``None``, the shared module-level
            ``random`` state is used.

    Returns:
        A list with one ``[train_data, test_data, train_labels, test_labels]``
        entry (each a list) per fold.

    Raises:
        ValueError: if *data* and *labels* differ in length.
    """
    # Validate with an exception rather than ``assert``, which is stripped
    # when Python runs with -O.
    if len(data) != len(labels):
        raise ValueError(
            f"data and labels differ in length: {len(data)} != {len(labels)}")
    labeled_data: List[Tuple[str, int]] = list(zip(data, labels))
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(labeled_data)
    folds: List[List[Tuple[str, int]]] = list(chunks(labeled_data, max_splits))
    samples_list = []
    for index, test_list in enumerate(folds):
        test_data, test_labels = unzip(test_list)
        train_list = list(itertools.chain.from_iterable(
            fold for fold_index, fold in enumerate(folds) if fold_index != index))
        train_data, train_labels = unzip(train_list)
        samples_list.append([train_data, test_data, train_labels, test_labels])
    return samples_list


def vectorize_data(train_data, test_data, data_untagged):
    """Encode three text collections as bag-of-n-gram count matrices.

    The uni-, bi- and tri-gram vocabulary is learned from *train_data* alone.
    *test_data* and *data_untagged* are encoded against that same vocabulary,
    so all three matrices share their columns.

    Returns:
        ``(train_matrix, test_matrix, untagged_matrix)``.
    """
    vectorizer = CountVectorizer(ngram_range=(1, 3))
    train_matrix = vectorizer.fit_transform(train_data)
    test_matrix, untagged_matrix = (
        vectorizer.transform(texts) for texts in (test_data, data_untagged))
    return train_matrix, test_matrix, untagged_matrix


# Sentiment classes, in the column order used by every ensemble score array.
_CLASSES = (-1, 0, 1)


def _print_time(label):
    """Print *label* followed by the current local time as HH:MM:SS."""
    print(label, datetime.now().strftime("%H:%M:%S"))


def _oversample(train_data, train_labels):
    """Over-sample the training split with ADASYN when ADASYN accepts the data.

    Returns the possibly resampled ``(train_data, train_labels)``. A
    ``ValueError`` from ADASYN is reported as an already balanced sample, and
    the inputs are returned unchanged.
    NOTE(review): ADASYN may raise ValueError for other reasons too; those are
    reported the same way.
    """
    try:
        train_data, train_labels = ADASYN().fit_resample(train_data, train_labels)
        print("Sample imbalance: ADASYN applied.")
    except ValueError:
        print("Sample balanced.")
    return train_data, train_labels


def _fit_scored(model, train_data, train_labels, test_data, test_labels):
    """Fit *model* on the training split; return ``(model, test accuracy)``."""
    model.fit(train_data, train_labels)
    return model, model.score(test_data, test_labels)


def _best_random_forest(train_data, train_labels, test_data, test_labels,
                        min_n, max_n):
    """Search for the forest size that scores best on the test split.

    Fits ``RandomForestClassifier(n_estimators=n, random_state=42)`` for every
    ``n`` in ``range(min_n, max_n)`` (``max_n`` itself is not tried). Returns
    ``(n, model, accuracy)`` for the smallest ``n`` that reaches the highest
    test accuracy. The winning fitted model is returned directly; refitting it
    with the same seed and data would only reproduce the same forest.

    NOTE(review): ``n`` is selected on the same split that later supplies the
    ensemble weights and the reported metrics, so those are optimistic.

    Raises:
        ValueError: if the range of forest sizes is empty.
    """
    best = None
    for n in range(min_n, max_n):
        model, accuracy = _fit_scored(
            RandomForestClassifier(n_estimators=n, random_state=42),
            train_data, train_labels, test_data, test_labels)
        if best is None or accuracy > best[2]:
            best = (n, model, accuracy)
    if best is None:
        raise ValueError(f"empty forest-size range: [{min_n}, {max_n})")
    return best


def _train_lr_and_bayes(train_data, train_labels, test_data, test_labels):
    """Fit logistic regression and multinomial naive Bayes, printing accuracies.

    Returns ``[(logistic_regression, accuracy), (bayes, accuracy)]``.
    """
    # multi_class is left at its default ('auto'); passing it explicitly is
    # deprecated in recent scikit-learn releases.
    logistic_regression, logistic_regression_weight = _fit_scored(
        LogisticRegression(solver='lbfgs', max_iter=2000),
        train_data, train_labels, test_data, test_labels)
    print(f"Logistic Regression trained at: {logistic_regression_weight}")
    bayes, bayes_weight = _fit_scored(
        MultinomialNB(), train_data, train_labels, test_data, test_labels)
    print(f"Bayes trained at: {bayes_weight}")
    return [(logistic_regression, logistic_regression_weight),
            (bayes, bayes_weight)]


def _aligned_proba(model, data):
    """Return ``model.predict_proba(data)`` with columns ordered as ``_CLASSES``.

    Columns are mapped through ``model.classes_`` instead of by position. A
    class absent from the training labels therefore yields a zero column
    rather than an IndexError or a shifted column.
    """
    proba = model.predict_proba(data)
    aligned = np.zeros((proba.shape[0], len(_CLASSES)))
    for column, label in enumerate(model.classes_):
        if label in _CLASSES:
            aligned[:, _CLASSES.index(label)] = proba[:, column]
    return aligned


def _ensemble_scores(classifiers, data, acc_cut_off):
    """Sum accuracy-weighted class probabilities over the included classifiers.

    A classifier ``(model, weight)`` is included when ``weight >= acc_cut_off``.
    Returns ``(scores, included)``: ``scores`` is an array with one row per
    document and columns ordered as ``_CLASSES``; ``included`` is the number
    of classifiers that passed the cut-off.
    """
    scores = np.zeros((data.shape[0], len(_CLASSES)))
    included = 0
    for model, weight in classifiers:
        if weight >= acc_cut_off:
            scores = scores + _aligned_proba(model, data) * weight
            included += 1
    return scores, included


def _strict_argmax(row):
    """Return the index (0, 1, 2) of the strictly largest score, or None on a tie."""
    p_neg, p_zero, p_pos = row
    if p_neg > p_zero and p_neg > p_pos:
        return 0
    if p_zero > p_neg and p_zero > p_pos:
        return 1
    if p_pos > p_neg and p_pos > p_zero:
        return 2
    return None


def _best_classifier(classifiers):
    """Return the ``(model, weight)`` pair with the highest weight.

    On equal weights the later pair in *classifiers* wins, matching a stable
    ascending sort followed by taking the last element.
    """
    return max(reversed(classifiers), key=itemgetter(1))


def _predict_labels(classifiers, data, acc_cut_off):
    """Predict one class label (-1, 0 or 1) per row of *data* with the ensemble.

    A row gets the class with the strictly highest ensemble score. On a tie,
    including the case where no classifier passes the cut-off and all scores
    are 0, the row gets the prediction of the single best classifier.
    """
    scores, _ = _ensemble_scores(classifiers, data, acc_cut_off)
    fallback_labels = _best_classifier(classifiers)[0].predict(data)
    predictions = []
    for row, fallback_label in zip(scores, fallback_labels):
        winner = _strict_argmax(row)
        predictions.append(fallback_label if winner is None else _CLASSES[winner])
    return predictions


def _predict_strengths(classifiers, data, acc_cut_off):
    """Predict a signed sentiment strength per row of *data*.

    When one class strictly wins the ensemble score, the value is:

    - ``-(score / included)`` for -1,
    - ``0`` for 0,
    - ``+(score / included)`` for 1,

    where ``included`` is the number of classifiers that passed the cut-off.

    On a tie, the best single classifier decides. Its probability for the
    predicted class is used, negated for -1, and 0 is used for a neutral
    prediction. Returns an empty list for an empty *data*.
    """
    if data.shape[0] == 0:
        return []
    scores, included = _ensemble_scores(classifiers, data, acc_cut_off)
    best_model = _best_classifier(classifiers)[0]
    fallback_labels = best_model.predict(data)
    fallback_proba = _aligned_proba(best_model, data)
    predictions = []
    for row, label, proba in zip(scores, fallback_labels, fallback_proba):
        winner = _strict_argmax(row)
        # A strict winner implies a non-zero score, hence included >= 1.
        if winner == 0:
            predictions.append(-1 * (row[0] / included))
        elif winner == 1:
            predictions.append(0)
        elif winner == 2:
            predictions.append(row[2] / included)
        elif label == -1:
            predictions.append(-1 * proba[0])
        elif label == 1:
            predictions.append(proba[2])
        else:
            predictions.append(0)
    return predictions


def _cross_validate(sample_list, processed_untagged, acc_cut_off, min_n, max_n):
    """Train and evaluate the ensemble on every fold, printing progress.

    Returns a list of ``[accuracy, weighted_f1, forest_size, fold_index]``,
    one entry per fold.
    """
    results = []
    for sample_index, (train_texts, test_texts, train_labels, test_labels) in \
            enumerate(sample_list):
        print(f"Iteration: {sample_index + 1}")
        # The untagged matrix is not needed here; vectorize_data only
        # requires the argument.
        train_data, test_data, _ = vectorize_data(
            train_texts, test_texts, processed_untagged)
        train_data, train_labels = _oversample(train_data, train_labels)
        n, random_forest, random_forest_weight = _best_random_forest(
            train_data, train_labels, test_data, test_labels, min_n, max_n)
        print(f"Random Forest trained with n={n} at: {random_forest_weight}")
        classifiers = [(random_forest, random_forest_weight)]
        classifiers += _train_lr_and_bayes(
            train_data, train_labels, test_data, test_labels)
        print("Trained classifiers.")
        test_predictions = _predict_labels(classifiers, test_data, acc_cut_off)
        accuracy = accuracy_score(test_labels, test_predictions, normalize=True)
        # NOTE(review): labels=... restricts the weighted F1 to classes that
        # were predicted at least once, so classes never predicted do not
        # lower the score. Confirm this is intended.
        f1score = f1_score(test_labels, test_predictions,
                           average='weighted', labels=np.unique(test_predictions))
        results.append([accuracy, f1score, n, sample_index])
        print(f"Accuracy: {accuracy}")
        print(f"F1-Score: {f1score}")
        print("******************************")
    return results


def _predict_untagged(fold, n, processed_untagged, acc_cut_off):
    """Retrain on *fold* with forest size *n* and score the untagged texts."""
    train_texts, test_texts, train_labels, test_labels = fold
    train_data, test_data, data_untagged = vectorize_data(
        train_texts, test_texts, processed_untagged)
    train_data, train_labels = _oversample(train_data, train_labels)
    random_forest, random_forest_weight = _fit_scored(
        RandomForestClassifier(n_estimators=n, random_state=42),
        train_data, train_labels, test_data, test_labels)
    print(f"Random Forest trained with n={n} at: {random_forest_weight}")
    classifiers = [(random_forest, random_forest_weight)]
    classifiers += _train_lr_and_bayes(
        train_data, train_labels, test_data, test_labels)
    return _predict_strengths(classifiers, data_untagged, acc_cut_off)


def _process_file(filename, max_splits, acc_cut_off, min_n, max_n):
    """Cross-validate on one tagged CSV, then score and save its untagged rows."""
    print(f"Filename: {filename}")
    print("*****************")
    # read data and separate tagged and untagged
    df = pd.read_csv(f"training_set/{filename}_tagged_human.csv")
    df_tagged = df[df.tagged == 1].copy()
    df_untagged = df[df.tagged == 0].copy()
    print(f"{len(df)} documents imported. {len(df_tagged)} are tagged and {len(df_untagged)} are untagged.")
    if df_tagged.empty:
        # Without tagged documents there is nothing to train on.
        print("No tagged documents: skipping file.")
        print("*****************")
        return
    # map score to either 1, -1, or 0
    df_tagged.loc[:, 'score'] = df_tagged['score'].apply(
        lambda s: -1 if s < 0 else (1 if s > 0 else 0))
    labels = df_tagged['score']
    print(df_tagged.groupby(['score']).count())
    # Preprocessing is deterministic per document, so do it once here rather
    # than once per fold; the folds contain exactly the same strings.
    processed_untagged = preprocessfunction(df_untagged['text'])
    sample_list = split_datafunction(
        preprocessfunction(df_tagged['text']), labels, max_splits=max_splits)

    results = _cross_validate(
        sample_list, processed_untagged, acc_cut_off, min_n, max_n)
    # sort results by accuracy (stable) and print cross-validation means
    results.sort(key=itemgetter(0))
    final_acc = statistics.mean([item[0] for item in results])
    final_f1 = statistics.mean([item[1] for item in results])
    print(f"Final Accuracy: {final_acc}")
    print(f"Final F1-Score: {final_f1}")
    _print_time("Time completed: ")
    print("*****************")

    # Retrain on the fold with the highest accuracy, reusing its forest size.
    _, _, n, sample_index = results[-1]
    predictions = _predict_untagged(
        sample_list[sample_index], n, processed_untagged, acc_cut_off)
    print(f"Predicted {len(predictions)} untagged units.")
    _print_time("Time completed: ")
    print("*****************")

    # mark machine-scored rows with tagged == 2 and save to file
    df_untagged['tagged'] = 2
    df_untagged['score'] = predictions
    pd.concat([df_tagged, df_untagged]).to_csv(
        f"results_machine_learning/{filename}_tagged.csv")
    print("File saved.")
    _print_time("Time completed: ")
    print("*****************")


def main(filenames=("pp", "drs", "bulletin", "sz", "spiegel", "taz", "zeit"),
         max_splits=10, acc_cut_off=0.9, min_n=1, max_n=300):
    """Train and apply the sentiment ensemble for each dataset name.

    For every name, this reads ``training_set/{name}_tagged_human.csv``. Rows
    with ``tagged == 1`` are training data, their ``score`` reduced to its
    sign. Rows with ``tagged == 0`` are scored. The steps are:

    1. Run ``max_splits``-fold cross-validation of an accuracy-weighted
       ensemble (random forest, logistic regression, naive Bayes).
    2. Retrain on the best fold.
    3. Write every row to ``results_machine_learning/{name}_tagged.csv``,
       with scored rows marked ``tagged == 2``.

    Args:
        filenames: dataset name stems to process.
        max_splits: number of cross-validation folds.
        acc_cut_off: minimum test accuracy for a classifier to vote.
        min_n: smallest random-forest size tried.
        max_n: exclusive upper bound of random-forest sizes tried.
    """
    print("*****************")
    _print_time("Time initiated: ")
    print("*****************")
    print("Parameters:")
    print(f"Max Splits: {max_splits}")
    print(f"Accuracy cut-off value: {acc_cut_off}")
    print(f"Min n: {min_n}")
    print(f"Max n: {max_n}")
    print("*****************")
    for filename in filenames:
        _process_file(filename, max_splits, acc_cut_off, min_n, max_n)


if __name__ == "__main__":
    from contextlib import redirect_stdout

    # Send everything main() prints to output.txt. redirect_stdout restores
    # sys.stdout on exit, even if main() raises.
    with open("output.txt", "w", encoding="utf-8") as f, redirect_stdout(f):
        main()